{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c606146e-f98a-476a-ab2e-84b96f7a97fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sample text for the demo: a Chinese passage followed by two English paragraphs, kept as one string.\n",
    "text = \"\"\"蔚来萤火虫即将交付，李斌所预测的两年内新能源车渗透率超80%能否实现？\n",
    "\n",
    "9 月 6 日消息，近日，蔚来汽车创始人、董事长李斌在蔚来九周年内部讲话及财报电话会上透露了多项重要信息，其中最引人注目的莫过于第三品牌萤火虫将于 2025 年正式交付，并预测新能源汽车的市场渗透率将在未来两年内超过 80%。\n",
    "\n",
    "据李斌介绍，蔚来汽车将形成三个品牌矩阵，覆盖从 14 万元到 80 万元的广阔市场区间。其中，第三品牌萤火虫作为蔚来汽车布局中低端市场的重要棋子，将于 2025 年正式交付。这一举措不仅丰富了蔚来的产品线，也进一步提升了其在新能源汽车市场的竞争力。\n",
    "\n",
    "This chunker works by determining when to \"break\" apart sentences. This is done by looking for differences in embeddings between any two sentences. When that difference is past some threshold, then they are split.\n",
    "\n",
    "There are a few ways to determine what that threshold is, which are controlled by the breakpoint_threshold_type kwarg.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "9160332e-a993-434b-b88b-7d508850b270",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "# chunk_size and chunk_overlap are measured with len, i.e. in characters.\n",
    "text_splitter = RecursiveCharacterTextSplitter(\n",
    "    chunk_size=64,\n",
    "    chunk_overlap=12,\n",
    "    length_function=len,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "b4dab907-58d9-407e-a221-f8652520f642",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_experimental.text_splitter import SemanticChunker\n",
    "from langchain_community.embeddings import HuggingFaceEmbeddings\n",
    "from langchain.vectorstores import FAISS\n",
    "\n",
    "# Embedding model settings: local checkpoint path, CPU device, normalized embeddings.\n",
    "model_name = \"../../bge-small-zh-v1.5\"\n",
    "model_kwargs = dict(device='cpu')\n",
    "encode_kwargs = dict(normalize_embeddings=True)\n",
    "\n",
    "hf_embedding = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "f3f43fe3-6f8b-478d-be93-0ff01e04b08d",
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.docstore.document import Document"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "d3a0e66e-9a2f-468e-b487-84d69379d808",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Split the sample text into chunks, wrap each chunk in a Document, then build the FAISS store.\n",
    "texts = text_splitter.split_text(text)\n",
    "doc_texts = [\n",
    "    Document(page_content=chunk, metadata={})\n",
    "    for chunk in texts\n",
    "]\n",
    "vectorstore = FAISS.from_documents(doc_texts, hf_embedding)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "9348017c-bb25-45ea-bb65-1f29c77d4f35",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rank_bm25 import BM25Okapi"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "5a68b1e9-d4cf-46f2-8c9a-d5a47f9f25fe",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Building prefix dict from the default dictionary ...\n",
      "Dumping model to file cache C:\\Users\\13494\\AppData\\Local\\Temp\\jieba.cache\n",
      "Loading model cost 1.473 seconds.\n",
      "Prefix dict has been built successfully.\n"
     ]
    }
   ],
   "source": [
    "import jieba\n",
    "# Build the BM25 index from the jieba-tokenized chunks.\n",
    "bm25 = BM25Okapi(\n",
    "    [jieba.lcut(chunk) for chunk in texts]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "40f5ef41-a022-41ae-81aa-cb5048699428",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 49,
   "id": "d8962de5-a763-4819-b945-ca20069205ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "def _min_max_normalize(scores):\n",
    "    \"\"\"Rescale scores to the range [0, 1]; a constant input maps to all zeros.\"\"\"\n",
    "    scores = np.asarray(scores)\n",
    "    lowest = scores.min()\n",
    "    spread = scores.max() - lowest\n",
    "    if spread == 0:\n",
    "        # All candidates tie, so this signal cannot rank them; dividing by the\n",
    "        # zero spread would fill the result with NaN instead.\n",
    "        return np.zeros_like(scores, dtype=float)\n",
    "    return (scores - lowest) / spread\n",
    "\n",
    "\n",
    "def fusion_retrieval(vectorstore, bm25, query: str, k: int = 5, alpha: float = 0.5, docs=None):\n",
    "    \"\"\"Rank text chunks by a weighted blend of vector-search and BM25 scores.\n",
    "\n",
    "    Args:\n",
    "        vectorstore: Object exposing ``similarity_search_with_score``; the score it\n",
    "            returns is treated as a distance (smaller means more similar).\n",
    "        bm25: Object exposing ``get_scores``; its scores are assumed to line up\n",
    "            index-by-index with ``docs``.\n",
    "        query: Query string; it is tokenized with jieba for the BM25 lookup.\n",
    "        k: Number of chunks to return.\n",
    "        alpha: Weight of the vector score; the BM25 score gets ``1 - alpha``.\n",
    "        docs: Chunk strings to rank. Defaults to the notebook-level ``texts``.\n",
    "\n",
    "    Returns:\n",
    "        Up to ``k`` chunk strings, highest combined score first.\n",
    "\n",
    "    Raises:\n",
    "        KeyError: If a chunk in ``docs`` is absent from the vector search results.\n",
    "    \"\"\"\n",
    "    # Step 1: Take the chunk list directly. The original version fetched all_docs\n",
    "    # from the vectorstore, which returned them in the wrong order.\n",
    "    all_docs = texts if docs is None else docs\n",
    "    if len(all_docs) == 0:\n",
    "        return []\n",
    "\n",
    "    # Step 2: Perform BM25 search\n",
    "    bm25_scores = bm25.get_scores(jieba.lcut(query))\n",
    "\n",
    "    # Step 3: Perform vector search over every chunk and key the scores by chunk\n",
    "    # content so they can be put back into all_docs order.\n",
    "    vector_results = vectorstore.similarity_search_with_score(query, k=len(all_docs))\n",
    "    content_to_distance = {doc.page_content: score for doc, score in vector_results}\n",
    "    # Direct indexing fails loudly on a chunk the vector search did not return.\n",
    "    distances = [content_to_distance[doc] for doc in all_docs]\n",
    "\n",
    "    # Step 4: Normalize scores; distances are inverted so that larger means better.\n",
    "    vector_scores = 1 - _min_max_normalize(distances)\n",
    "    bm25_scores = _min_max_normalize(bm25_scores)\n",
    "\n",
    "    # Step 5: Combine scores\n",
    "    combined_scores = alpha * vector_scores + (1 - alpha) * bm25_scores\n",
    "\n",
    "    # Step 6: Rank documents, best first\n",
    "    sorted_indices = np.argsort(combined_scores)[::-1]\n",
    "\n",
    "    # Step 7: Return top k documents\n",
    "    return [all_docs[i] for i in sorted_indices[:k]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "id": "517c7ec7-9a10-4671-814c-75ff1d471e0c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['日消息，近日，蔚来汽车创始人、董事长李斌在蔚来九周年内部讲话及财报电话会上透露了多项重要信息，其中最引人注目的莫过于第三品牌萤', '据李斌介绍，蔚来汽车将形成三个品牌矩阵，覆盖从 14 万元到 80']\n"
     ]
    }
   ],
   "source": [
    "# Run the fused retrieval for a short query and print the two best chunks.\n",
    "query = \"蔚来\"\n",
    "top_docs = fusion_retrieval(vectorstore, bm25, query=query, k=2, alpha=0.5)\n",
    "\n",
    "docs_content = list(top_docs)\n",
    "print(docs_content)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "be8fa0a5-1f86-4c74-a1b7-5ef9a5cda4bc",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
